#!/usr/bin/env python3
"""
get_flip_counts.py  —  fast version

Simulates forward tick sequences on a 3-state flip network {IN, CS, ON}
for each context index n, and writes:

  - results/flip_counts_summary.csv       (counts per n for this seed)
  - results/flip_rates_by_context.csv     (row-stochastic rates per n for this seed)
  - results/flip_counts_raw.csv           (optional, if --log_raw)

Key fix: one RNG per n (derived from base seed) — no per-step RNG creation.
"""

from __future__ import annotations
import argparse, csv, hashlib
from pathlib import Path
from typing import Dict, List
import numpy as np
import pandas as pd

# Integer codes for the three states; they are also used as the row/column
# indices of the 3x3 transition and count matrices.
STATE_IN = 0
STATE_CS = 1
STATE_ON = 2

# Display label for each state code (used in the raw per-tick log).
STATE_NAMES = {
    STATE_IN: "IN",
    STATE_CS: "CS",
    STATE_ON: "ON",
}

# ---------- helpers ----------

def _load_n_list(repo_root: Path, nmin: float, nmax: float, step: float, use_dfile: bool=True) -> List[float]:
    """Return the list of context indices n to simulate.

    If ``use_dfile`` is true and ``repo_root / "D_values.csv"`` exists and has
    a column named 'n', that column is returned with every entry converted to
    float.  Otherwise the inclusive range [nmin, nmax] is generated by
    repeatedly adding ``step``; each value is rounded to 4 decimals and the
    upper bound carries a 1e-12 tolerance to absorb float accumulation error.

    Raises:
        ValueError: if the range fallback is used with a non-positive ``step``
            while nmin <= nmax, since that loop could never terminate.
    """
    if use_dfile:
        dfile = repo_root / "D_values.csv"
        if dfile.exists():
            df = pd.read_csv(dfile)
            if "n" in df.columns:
                return [float(x) for x in df["n"].tolist()]

    upper = nmax + 1e-12
    # A non-positive step never moves n past the upper bound; fail fast
    # instead of appending values forever.
    if step <= 0 and nmin <= upper:
        raise ValueError(f"step must be positive to build range [{nmin}, {nmax}], got {step}")

    vals, n = [], nmin
    while n <= upper:
        vals.append(round(n, 4))
        n += step
    return vals

def _load_D_map(repo_root: Path) -> Dict[float, float]:
    """Read ``repo_root / "D_values.csv"`` into a mapping n -> D.

    Returns an empty dict when the file does not exist or when it lacks
    either of the columns 'n' and 'D'.  Keys and values are converted to
    float; if an n value repeats, the last row wins.
    """
    csv_path = repo_root / "D_values.csv"
    if not csv_path.exists():
        return {}

    table = pd.read_csv(csv_path)
    if not {"n", "D"}.issubset(table.columns):
        return {}

    mapping: Dict[float, float] = {}
    for _, record in table.iterrows():
        key = float(record["n"])
        mapping[key] = float(record["D"])
    return mapping

def g_pivot_normalized(D: float) -> float:
    """
    Placeholder scaling factor s(D), piecewise-linear and clipped at 0:

      D <  2      -> 1
      2 <= D      -> max(0, 1 - (D - 2))   (so 1 at D=2, reaching 0 at D=3)

    Any D for which ``D >= 2`` is false (including NaN) yields 1.
    """
    if D >= 2.0:
        falloff = 1.0 - (D - 2.0)
        return max(0.0, falloff)
    return 1.0

def build_P(D: float|None, mode: str, alpha: float, beta: float, gamma: float) -> np.ndarray:
    """
    3x3 row-stochastic transition matrix for [IN, CS, ON].

      IN: [1-a, a,   0]
      CS: [b/2,1-b,b/2]
      ON: [0,   g, 1-g]

    mode='constant' uses (alpha,beta,gamma) as (a,b,g).
    mode='Dscaled' (case-insensitive) multiplies them by s(D)=g_pivot_normalized(D);
    when D is None there is nothing to scale by, so the constant rates are used.

    In both modes each rate is clipped to [0, 1] so that every entry of the
    matrix is a valid probability (an out-of-range rate would otherwise yield
    negative entries that the sampler rejects).

    Returns:
        float ndarray of shape (3, 3) whose rows each sum to 1.
    """
    scale = 1.0
    if mode.lower() == "dscaled" and D is not None:
        scale = g_pivot_normalized(D)

    # Multiplying by 1.0 is exact, so in-range constant rates are unchanged.
    a, b, g = (float(np.clip(rate * scale, 0.0, 1.0)) for rate in (alpha, beta, gamma))

    P = np.array([
        [1.0 - a, a,         0.0],
        [b/2.0,   1.0 - b,   b/2.0],
        [0.0,     g,         1.0 - g],
    ], dtype=float)
    P /= P.sum(axis=1, keepdims=True)  # remove float rounding so rows sum to 1
    return P

def _derive_seed(base_seed: int, n: float) -> int:
    """Derive a reproducible 64-bit seed for context n from the base seed.

    The text "<base_seed>|<repr(n)>" is hashed with SHA-256 and the first
    8 digest bytes are read as an unsigned little-endian integer.  Because
    repr(n) is part of the hashed text, an int n (e.g. 3) and the equal
    float (3.0) produce different seeds.
    """
    token = f"{base_seed}|{n!r}"
    digest = hashlib.sha256(token.encode("utf-8")).digest()
    leading = digest[:8]
    return int.from_bytes(leading, byteorder="little", signed=False)

# ---------- main ----------

def main():
    """Parse CLI options, simulate every context n, and write the CSV outputs.

    The contexts come from the 'n' column of D_values.csv under --repo_root
    when available, else from the --nmin/--nmax/--nstep range.  For each n a
    transition matrix is built with build_P and a chain of --steps ticks is
    sampled from --start_state, using one RNG seeded by
    _derive_seed(--seed, n).  Transition counts are written to --out_summary,
    per-source-state empirical rates to --out_rates, and (with --log_raw)
    every individual tick to --out_raw.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--seed",  type=int, required=True)
    ap.add_argument("--steps", type=int, required=True)
    ap.add_argument("--nmin",  type=float, default=-3)
    ap.add_argument("--nmax",  type=float, default=3)
    ap.add_argument("--nstep", type=float, default=1.0)
    ap.add_argument("--mode", choices=["constant","Dscaled"], default="constant")
    ap.add_argument("--alpha", type=float, default=0.20, help="IN->CS base rate")
    ap.add_argument("--beta",  type=float, default=0.30, help="CS split rate to (IN,ON)")
    ap.add_argument("--gamma", type=float, default=0.20, help="ON->CS base rate")
    ap.add_argument("--start_state", choices=["IN","CS","ON"], default="CS")
    ap.add_argument("--repo_root", default=".", help="where D_values.csv lives")
    ap.add_argument("--out_raw",     default="results/flip_counts_raw.csv")
    ap.add_argument("--out_summary", default="results/flip_counts_summary.csv")
    ap.add_argument("--out_rates",   default="results/flip_rates_by_context.csv")
    ap.add_argument("--log_raw", action="store_true")
    ap.add_argument("--progress", type=int, default=0, help="print every N steps (per n), 0=off")
    args = ap.parse_args()

    repo_root = Path(args.repo_root).resolve()

    # Every output may point at a different directory, so create each parent
    # (the raw log only when it is actually going to be written).
    out_paths = [args.out_summary, args.out_rates]
    if args.log_raw:
        out_paths.append(args.out_raw)
    for out_path in out_paths:
        Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    n_list = _load_n_list(repo_root, args.nmin, args.nmax, args.nstep, use_dfile=True)
    D_map  = _load_D_map(repo_root)
    start  = {"IN":STATE_IN, "CS":STATE_CS, "ON":STATE_ON}[args.start_state]

    # Transitions reported in the CSVs, in output column order.
    reported = (
        (STATE_IN, STATE_CS),
        (STATE_CS, STATE_ON),
        (STATE_ON, STATE_CS),
        (STATE_CS, STATE_IN),
        (STATE_IN, STATE_IN),
        (STATE_CS, STATE_CS),
        (STATE_ON, STATE_ON),
    )

    rows_sum, rows_rates = [], []
    eps = 1e-12  # denominator floor: a never-visited source state gets rate 0

    fraw = None
    raw_writer = None
    try:
        # Optional raw log (beware size if steps*len(n) is large)
        if args.log_raw:
            if args.steps * max(1, len(n_list)) > 2_000_000:
                print("WARNING: --log_raw with large steps/contexts will produce a very large file.")
            fraw = open(args.out_raw, "w", newline="", encoding="utf-8")
            raw_writer = csv.writer(fraw)
            raw_writer.writerow(["seed","n","step","prev_state","next_state"])

        for n in n_list:
            Dn = D_map.get(n)
            P  = build_P(Dn, args.mode, args.alpha, args.beta, args.gamma)
            rng = np.random.default_rng(_derive_seed(args.seed, n))  # one RNG per context

            # counts[i, j] = number of observed ticks i -> j
            counts = np.zeros((3,3), dtype=np.int64)
            s = start
            for step in range(args.steps):
                nxt = rng.choice(3, p=P[s])
                counts[s, nxt] += 1
                if raw_writer is not None:
                    raw_writer.writerow([args.seed, n, step, STATE_NAMES[s], STATE_NAMES[nxt]])
                s = nxt
                if args.progress and (step + 1) % args.progress == 0:
                    print(f"n={n} step={step+1}/{args.steps}")

            # counts summary
            row = {"seed": args.seed, "n": n, "total_ticks": int(counts.sum())}
            for src, dst in reported:
                row[f"count_{STATE_NAMES[src]}_to_{STATE_NAMES[dst]}"] = int(counts[src, dst])
            rows_sum.append(row)

            # rates: each count divided by the total ticks leaving its source state
            out = counts.sum(axis=1)
            rates = {"seed": args.seed, "n": n}
            for src, dst in reported:
                rates[f"rate_{STATE_NAMES[src]}_to_{STATE_NAMES[dst]}"] = counts[src, dst] / max(eps, out[src])
            for src in (STATE_IN, STATE_CS, STATE_ON):
                rates[f"rowsum_{STATE_NAMES[src]}"] = out[src] / max(eps, out[src])
            rows_rates.append(rates)
    finally:
        # Close the raw log even if the simulation raised part-way through.
        if fraw is not None:
            fraw.close()

    pd.DataFrame(rows_sum).to_csv(args.out_summary, index=False)
    pd.DataFrame(rows_rates).to_csv(args.out_rates, index=False)

    print(f"Wrote {args.out_summary} and {args.out_rates}")

if __name__ == "__main__":
    main()
